{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预热阶段"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31mSignature:\u001b[39m prerequisite_optimize(mod, params=\u001b[38;5;28;01mNone\u001b[39;00m)\n",
      "\u001b[31mSource:\u001b[39m   \n",
      "\u001b[38;5;28;01mdef\u001b[39;00m prerequisite_optimize(mod, params=\u001b[38;5;28;01mNone\u001b[39;00m):\n",
      "    \u001b[33m\"\"\"Prerequisite optimization passes for quantization. Perform\u001b[39m\n",
      "\u001b[33m    \"SimplifyInference\", \"FoldScaleAxis\", \"FoldConstant\", and\u001b[39m\n",
      "\u001b[33m    \"CanonicalizeOps\" optimization before quantization.\"\"\"\u001b[39m\n",
      "    optimize = tvm.transform.Sequential(\n",
      "        [\n",
      "            _transform.SimplifyInference(),\n",
      "            _transform.FoldConstant(),\n",
      "            _transform.FoldScaleAxis(),\n",
      "            _transform.CanonicalizeOps(),\n",
      "            _transform.FoldConstant(),\n",
      "        ]\n",
      "    )\n",
      "\n",
      "    \u001b[38;5;28;01mif\u001b[39;00m params:\n",
      "        mod[\u001b[33m\"main\"\u001b[39m] = _bind_params(mod[\u001b[33m\"main\"\u001b[39m], params)\n",
      "\n",
      "    mod = optimize(mod)\n",
      "    \u001b[38;5;28;01mreturn\u001b[39;00m mod\n",
      "\u001b[31mFile:\u001b[39m      /media/pc/data/board/arria10/lxw/tasks/tvm-ai/python/tvm/relay/quantize/quantize.py\n",
      "\u001b[31mType:\u001b[39m      function"
     ]
    }
   ],
   "source": [
    "# NOTE(review): set_env is a project-local helper; presumably it prepares the\n",
    "# environment so that the local TVM build can be imported -- TODO confirm.\n",
    "import set_env\n",
    "from tvm.relay.quantize.quantize import prerequisite_optimize\n",
    "# IPython `??` introspection: displays the signature and source of the function.\n",
    "prerequisite_optimize??"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面以一个由 Conv2d、BatchNorm2d 和 ReLU 组成的简单模型为例，逐步展示预热阶段的各个步骤："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%data: Tensor[(1, 3, 4, 4), float32] /* span=aten___convolution_0_data:0:0 */, %aten___convolution_0_weight: Tensor[(16, 3, 3, 3), float32] /* span=aten___convolution_0_weight:0:0 */, %aten___convolution_0_bias: Tensor[(16), float32] /* span=aten___convolution_0_bias:0:0 */, %aten__batch_norm_0_weight: Tensor[(16), float32] /* span=aten__batch_norm_0_weight:0:0 */, %aten__batch_norm_0_bias: Tensor[(16), float32] /* span=aten__batch_norm_0_bias:0:0 */, %aten__batch_norm_0_mean: Tensor[(16), float32] /* span=aten__batch_norm_0_mean:0:0 */, %aten__batch_norm_0_var: Tensor[(16), float32] /* span=aten__batch_norm_0_var:0:0 */) {\n",
      "  %0 = nn.conv2d(%data, %aten___convolution_0_weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* span=aten___convolution_0:0:0 */;\n",
      "  %1 = nn.bias_add(%0, %aten___convolution_0_bias) /* span=aten___convolution_0:0:0 */;\n",
      "  %2 = nn.batch_norm(%1, %aten__batch_norm_0_weight, %aten__batch_norm_0_bias, %aten__batch_norm_0_mean, %aten__batch_norm_0_var) /* span=aten__batch_norm_0:0:0 */;\n",
      "  %3 = %2.0 /* span=aten__batch_norm_0:0:0 */;\n",
      "  nn.relu(%3) /* span=aten__relu_0:0:0 */\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import tvm\n",
    "from tvm import relay\n",
    "from torch import nn\n",
    "import torch\n",
    "\n",
    "class Model(nn.Module):\n",
    "    \"\"\"Minimal Conv2d -> BatchNorm2d -> ReLU network used as the demo model.\"\"\"\n",
    "\n",
    "    def __init__(self, *args, **kwargs) -> None:\n",
    "        super().__init__(*args, **kwargs)\n",
    "        # 3 input channels -> 16 output channels, 3x3 kernel, stride 1, padding 1.\n",
    "        self.conv = nn.Conv2d(\n",
    "            in_channels=3,\n",
    "            out_channels=16,\n",
    "            kernel_size=3,\n",
    "            stride=1,\n",
    "            padding=1,\n",
    "            bias=True,\n",
    "        )\n",
    "        self.bn = nn.BatchNorm2d(num_features=16)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply convolution, batch normalization and ReLU, in that order.\"\"\"\n",
    "        return self.relu(self.bn(self.conv(x)))\n",
    "\n",
    "def create_model(ishape=(1, 3, 4, 4), input_name=\"data\"):\n",
    "    \"\"\"Trace the demo ``Model`` with TorchScript and import it into TVM Relay.\n",
    "\n",
    "    Args:\n",
    "        ishape: Shape of the random float32 tensor used for tracing; the same\n",
    "            shape is declared for the Relay input.\n",
    "        input_name: Name given to the Relay input variable.\n",
    "\n",
    "    Returns:\n",
    "        The ``(mod, params)`` pair produced by ``relay.frontend.from_pytorch``.\n",
    "    \"\"\"\n",
    "    # Put the module in eval mode before tracing.\n",
    "    pt_model = Model().eval().float()\n",
    "    example_input = torch.rand(ishape, dtype=torch.float32)\n",
    "    traced_model = torch.jit.trace(pt_model, example_input)\n",
    "    # Translate the traced module into a TVM Relay frontend module.\n",
    "    mod, params = relay.frontend.from_pytorch(\n",
    "        traced_model,\n",
    "        [(input_name, ishape)],\n",
    "        use_parser_friendly_name=True,\n",
    "    )\n",
    "    return mod, params\n",
    "\n",
    "# Build the Relay module and show the imported, not yet optimized, main function.\n",
    "mod, params = create_model(ishape=(1, 3, 4, 4))\n",
    "print(mod[\"main\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参数绑定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "绑定参数前：\n",
      "fn (%data: Tensor[(1, 3, 4, 4), float32] /* span=aten___convolution_0_data:0:0 */, %aten___convolution_0_weight: Tensor[(16, 3, 3, 3), float32] /* span=aten___convolution_0_weight:0:0 */, %aten___convolution_0_bias: Tensor[(16), float32] /* span=aten___convolution_0_bias:0:0 */, %aten__batch_norm_0_weight: Tensor[(16), float32] /* span=aten__batch_norm_0_weight:0:0 */, %aten__batch_norm_0_bias: Tensor[(16), float32] /* span=aten__batch_norm_0_bias:0:0 */, %aten__batch_norm_0_mean: Tensor[(16), float32] /* span=aten__batch_norm_0_mean:0:0 */, %aten__batch_norm_0_var: Tensor[(16), float32] /* span=aten__batch_norm_0_var:0:0 */) {\n",
      "  %0 = nn.conv2d(%data, %aten___convolution_0_weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* span=aten___convolution_0:0:0 */;\n",
      "  %1 = nn.bias_add(%0, %aten___convolution_0_bias) /* span=aten___convolution_0:0:0 */;\n",
      "  %2 = nn.batch_norm(%1, %aten__batch_norm_0_weight, %aten__batch_norm_0_bias, %aten__batch_norm_0_mean, %aten__batch_norm_0_var) /* span=aten__batch_norm_0:0:0 */;\n",
      "  %3 = %2.0 /* span=aten__batch_norm_0:0:0 */;\n",
      "  nn.relu(%3) /* span=aten__relu_0:0:0 */\n",
      "}\n",
      "==================================================\n",
      "绑定参数后：\n",
      "fn (%data: Tensor[(1, 3, 4, 4), float32] /* span=aten___convolution_0_data:0:0 */) {\n",
      "  %0 = nn.conv2d(%data, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* span=aten___convolution_0:0:0 */;\n",
      "  %1 = nn.bias_add(%0, meta[relay.Constant][1]) /* span=aten___convolution_0:0:0 */;\n",
      "  %2 = nn.batch_norm(%1, meta[relay.Constant][2], meta[relay.Constant][3], meta[relay.Constant][4], meta[relay.Constant][5]) /* span=aten__batch_norm_0:0:0 */;\n",
      "  %3 = %2.0 /* span=aten__batch_norm_0:0:0 */;\n",
      "  nn.relu(%3) /* span=aten__relu_0:0:0 */\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tvm.relay.quantize.quantize import _bind_params\n",
    "\n",
    "# Show the function while the weights are still free parameters.\n",
    "print(f\"绑定参数前：\\n{mod['main']}\")\n",
    "print(\"=\"*50)\n",
    "# Same call that prerequisite_optimize makes when params is given; in the recorded\n",
    "# output every parameter except %data becomes a meta[relay.Constant] entry.\n",
    "mod[\"main\"] = _bind_params(mod[\"main\"], params)\n",
    "print(f\"绑定参数后：\\n{mod['main']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型简化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {\n",
      "  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
      "  %1 = add(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
      "  %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;\n",
      "  nn.relu(%2) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */\n",
      "} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Same pass list, in the same order, as the Sequential inside prerequisite_optimize.\n",
    "optimize = tvm.transform.Sequential(\n",
    "    [\n",
    "        relay.transform.SimplifyInference(),\n",
    "        relay.transform.FoldConstant(),\n",
    "        relay.transform.FoldScaleAxis(),\n",
    "        relay.transform.CanonicalizeOps(),\n",
    "        relay.transform.FoldConstant(),\n",
    "    ]\n",
    ")\n",
    "# Run the passes under an explicit PassContext with opt_level=3.\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    mod = optimize(mod)\n",
    "# In the recorded output nn.bias_add and nn.batch_norm no longer appear; the\n",
    "# function is conv2d -> add -> add -> relu with constant operands.\n",
    "print(mod[\"main\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvm-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
